import numpy as np
import copy
import torch
from torch.autograd import Variable
import time
from .ufedbase import UnlearnBasicClient, UnlearnBasicServer
import numpy as np
from utils import fmodule
from tqdm import tqdm
class Server(UnlearnBasicServer):
    """Federated server that trains only on clients that are NOT being unlearned.

    Every round excludes the ids listed in ``self.unlearn_clients_id`` from
    participation, so the aggregated model never receives their updates.
    NOTE(review): this looks like a "retrain without the forgotten clients"
    baseline; whether ``self.model`` starts from a fresh initialization is
    decided by ``UnlearnBasicServer`` — confirm there.
    """

    def __init__(self, option, model, clients, data_loader, device=None):
        super(Server, self).__init__(option, model, clients, data_loader, device)

    def run(self):
        """Train for ``self.num_rounds`` rounds, reporting test metrics as it goes.

        The model is evaluated on the clients' ``'test'`` split once before
        training and once after every round. The log is saved after every
        round, and a checkpoint is written once the loop finishes.
        """
        self.current_rounds = 0
        test_metric = self.test_on_clients(dataflag='test', model=self.model)
        self.outFunc(t_metric=test_metric)

        for current_round in tqdm(range(1, self.num_rounds + 1), desc='Pretraining Rounds'):
            self.current_rounds = current_round
            # One round of federated training over the retained clients.
            self.iterate()

            # NOTE(review): this per-round call passes the total
            # ``self.num_rounds`` rather than ``current_round``. If the
            # scheduler's argument is meant to be the current round, this is
            # a bug — confirm against UnlearnBasicServer before changing it.
            self.global_lr_scheduler(self.num_rounds)

            test_metric = self.test_on_clients(dataflag='test', model=self.model)
            self.outFunc(test_metric)
            self.save_log(self.out_log)
        self.save_ckp()

    def iterate(self):
        """Run one training round on every client except those being unlearned.

        Sets ``self.selected_clients`` to ``self.clients_id`` minus every id in
        ``self.unlearn_clients_id``. The result is a NumPy array in the
        original order. The method then collects the local models from those
        clients and replaces ``self.model`` with their aggregate.
        """
        # np.asarray lets this work whether clients_id is a list or an array.
        # The boolean mask keeps only ids that are absent from the unlearn set.
        # TODO: the original author suspected a problem here — verify that the
        # excluded clients match the intended unlearning set.
        all_ids = np.asarray(self.clients_id)
        self.selected_clients = all_ids[~np.isin(all_ids, self.unlearn_clients_id)]

        # Replies are assumed to follow the order of self.selected_clients
        # (i.e. selected_clients == received_clients) — TODO confirm in communicate().
        reply = self.communicate(self.selected_clients)
        self.model = self.aggregate(reply['model'])

class Client(UnlearnBasicClient):
    """Client for this algorithm; adds nothing on top of ``UnlearnBasicClient``.

    Every behavior is inherited unchanged. The constructor exists only to keep
    the ``(option, id, model=None)`` signature explicit.
    """

    def __init__(self, option, id, model=None):
        # Straight pass-through to the base-class constructor.
        super().__init__(option, id, model)
